import json
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Union

# Avoid circular imports
if TYPE_CHECKING:
    from mirix.schemas.message import Message


class ErrorCode(Enum):
    """Machine-readable error codes used by the client.

    Every member's value is the same string as its name, so ``code.value``
    yields e.g. ``"RATE_LIMIT_EXCEEDED"``.
    """

    # Generic internal failure.
    INTERNAL_SERVER_ERROR = "INTERNAL_SERVER_ERROR"
    # Attached by ContextWindowExceededError.
    CONTEXT_WINDOW_EXCEEDED = "CONTEXT_WINDOW_EXCEEDED"
    # Attached by RateLimitExceededError.
    RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"


class MirixError(Exception):
    """Base class for all Mirix related errors.

    Attributes:
        message: Human-readable description; also passed to ``Exception``.
        code: Optional ``ErrorCode``. When set, ``str()`` prefixes the message
            with the code's value.
        details: Extra structured context. A fresh dict is created per
            instance when the caller does not supply one.
    """

    def __init__(self, message: str, code: Optional["ErrorCode"] = None, details: Optional[dict] = None):
        self.message = message
        self.code = code
        # A literal ``{}`` default would be a single dict shared by every
        # instance, so mutating one error's details would leak into others.
        self.details = {} if details is None else details
        super().__init__(message)

    def __str__(self) -> str:
        if self.code is not None:
            return f"{self.code.value}: {self.message}"
        return self.message

    def __repr__(self) -> str:
        # !r keeps the output unambiguous when the message contains quotes or
        # newlines, and shows a missing code as None rather than the text 'None'.
        return f"{self.__class__.__name__}(message={self.message!r}, code={self.code!r}, details={self.details!r})"


class MirixToolCreateError(MirixError):
    """Raised when a tool cannot be created.

    If ``message`` is omitted or empty, ``default_error_message`` is used.
    """

    default_error_message = "Error creating tool."

    def __init__(self, message=None):
        if not message:
            message = self.default_error_message
        super().__init__(message=message)


class MirixConfigurationError(MirixError):
    """Raised for configuration problems.

    Args:
        message: Human-readable description of the problem.
        missing_fields: Names of the offending/missing settings. Stored on
            ``self.missing_fields`` and mirrored in ``details["missing_fields"]``
            (the same list object); defaults to an empty list.
    """

    def __init__(self, message: str, missing_fields: Optional[List[str]] = None):
        fields = missing_fields if missing_fields else []
        self.missing_fields = fields
        super().__init__(message=message, details={"missing_fields": fields})


class MirixAgentNotFoundError(MirixError):
    """Raised when a requested agent does not exist."""

    pass


class MirixUserNotFoundError(MirixError):
    """Raised when a requested user does not exist."""

    pass


class LLMError(MirixError):
    """Common base class for LLM-related errors (authentication, bad request, rate limit, ...)."""


class LLMAuthenticationError(LLMError):
    """Raised when authenticating against the LLM fails."""


class LLMBadRequestError(LLMError):
    """Raised when a request sent to the LLM is malformed."""


class LLMConnectionError(LLMError):
    """Raised when a connection to the LLM cannot be established or is lost."""


class LLMNotFoundError(LLMError):
    """Raised when the requested LLM resource cannot be found."""


class LLMPermissionDeniedError(LLMError):
    """Raised when the LLM denies permission for the requested operation."""


class LLMRateLimitError(LLMError):
    """Raised when the LLM's rate limit has been exceeded."""


class LLMServerError(LLMError):
    """Raised when the LLM server reports an internal error."""


class LLMUnprocessableEntityError(LLMError):
    """Raised when the LLM accepts the request but cannot process its content."""


class BedrockPermissionError(MirixError):
    """Raised when access to the requested Bedrock model is not permitted.

    Args:
        message: Optional override for the default explanation.
    """

    def __init__(self, message="User does not have access to the Bedrock model with the specified ID."):
        super().__init__(message)


class BedrockError(MirixError):
    """Generic failure while working with a Bedrock model.

    Args:
        message: Optional override for the default explanation.
    """

    def __init__(self, message="Error with Bedrock model."):
        super().__init__(message)


class LLMJSONParsingError(MirixError):
    """Raised when JSON produced by an LLM cannot be parsed.

    Args:
        message: Optional override for the default explanation.
    """

    def __init__(self, message="Error parsing JSON generated by LLM"):
        super().__init__(message)


class LocalLLMError(MirixError):
    """Catch-all error for problems encountered while running a local LLM.

    Args:
        message: Optional override for the default explanation.
    """

    def __init__(self, message="Encountered an error while running local LLM"):
        super().__init__(message)


class LocalLLMConnectionError(MirixError):
    """Raised when the local LLM is unreachable at the configured IP/port.

    Note: this derives directly from ``MirixError``, not from ``LocalLLMError``.

    Args:
        message: Optional override for the default explanation.
    """

    def __init__(self, message="Could not connect to local LLM"):
        super().__init__(message)


class ContextWindowExceededError(MirixError):
    """Raised when the context window is exceeded and further summarization fails.

    Args:
        message: Human-readable description of the failure.
        details: Optional structured context. It is appended to the message as
            ``"<message> (<details>)"`` and stored on ``self.details``. A fresh
            empty dict is used when omitted.
    """

    def __init__(self, message: str, details: Optional[dict] = None):
        # Avoid a shared mutable ``{}`` default: it would be one dict for every
        # instance, so mutations on one error would leak into the next.
        if details is None:
            details = {}
        error_message = f"{message} ({details})"
        super().__init__(
            message=error_message,
            code=ErrorCode.CONTEXT_WINDOW_EXCEEDED,
            details=details,
        )


class RateLimitExceededError(MirixError):
    """Raised when the LLM rate limiter throttles API requests.

    Args:
        message: Human-readable description of the failure.
        max_retries: Retry count supplied by the caller. It is appended to the
            message in parentheses and recorded as ``details["max_retries"]``.
    """

    def __init__(self, message: str, max_retries: int):
        details = {"max_retries": max_retries}
        super().__init__(
            message=f"{message} ({max_retries})",
            code=ErrorCode.RATE_LIMIT_EXCEEDED,
            details=details,
        )


class MirixMessageError(MirixError):
    """Base error class for handling message-related errors.

    The exception text is ``default_error_message`` (plus an optional
    explanation), followed by a pretty-printed JSON dump of the offending
    messages. The messages themselves are kept on ``self.messages``.
    """

    # NOTE(review): "MirixMessage" is not imported in the TYPE_CHECKING block of
    # this file, so type checkers cannot resolve it; confirm its module and import it.
    messages: List[Union["Message", "MirixMessage"]]
    default_error_message: str = "An error occurred with the message."

    def __init__(self, *, messages: List[Union["Message", "MirixMessage"]], explanation: Optional[str] = None) -> None:
        """Build the formatted message and keep a reference to ``messages``.

        Args:
            messages: Objects exposing ``model_dump()``; dumped into the error text.
            explanation: Optional extra context appended to the default message.
        """
        error_msg = self.construct_error_message(messages, self.default_error_message, explanation)
        super().__init__(error_msg)
        self.messages = messages

    @staticmethod
    def construct_error_message(messages: List[Union["Message", "MirixMessage"]], error_msg: str, explanation: Optional[str] = None) -> str:
        """Return ``error_msg`` (+ explanation) followed by the messages as indented JSON."""
        if explanation:
            error_msg += f" (Explanation: {explanation})"

        # model_dump() may return values json cannot encode natively (e.g.
        # datetime fields, if the message models have any). Stringify them for
        # display instead of raising TypeError inside the exception constructor,
        # which would hide the real error.
        message_json = json.dumps([message.model_dump() for message in messages], indent=4, default=str)
        return f"{error_msg}\n\n{message_json}"


class MissingToolCallError(MirixMessageError):
    """Raised when a message that should contain a tool call has none."""

    default_error_message = "The message is missing a tool call."


class InvalidToolCallError(MirixMessageError):
    """Raised when a message's tool call is invalid or used improperly."""

    default_error_message = "The message uses an invalid tool call or has improper usage of a tool call."


class MissingInnerMonologueError(MirixMessageError):
    """Raised when a message lacks the expected inner monologue."""

    default_error_message = "The message is missing an inner monologue."


class InvalidInnerMonologueError(MirixMessageError):
    """Raised when a message's inner monologue is malformed."""

    default_error_message = "The message has a malformed inner monologue."
